[PATCH] [mlir][scf][bufferize] Fix bug in WhileOp analysis verification
authorMatthias Springer <me@m-sp.org>
Mon, 15 May 2023 13:39:35 +0000 (15:39 +0200)
committerGianfranco Costamagna <locutusofborg@debian.org>
Thu, 7 Sep 2023 22:43:45 +0000 (00:43 +0200)
Block arguments and yielded values are not equivalent if there are not enough block arguments. This fixes #59442.

Differential Revision: https://reviews.llvm.org/D145575

Gbp-Pq: Name CVE-2023-29933.patch

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir

index 9459640dad02b6667a3c97a807311145da324d4e..13f40aa57945a552a9893d57011c308bae89e6c8 100644 (file)
@@ -823,10 +823,12 @@ struct WhileOpInterface
 
     auto conditionOp = whileOp.getConditionOp();
     for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
+      Block *block = conditionOp->getBlock();
       if (!it.value().getType().isa<TensorType>())
         continue;
-      if (!state.areEquivalentBufferizedValues(
-              it.value(), conditionOp->getBlock()->getArgument(it.index())))
+      if (it.index() >= block->getNumArguments() ||
+          !state.areEquivalentBufferizedValues(it.value(),
+                                               block->getArgument(it.index())))
         return conditionOp->emitError()
                << "Condition arg #" << it.index()
                << " is not equivalent to the corresponding iter bbArg";
@@ -834,10 +836,12 @@ struct WhileOpInterface
 
     auto yieldOp = whileOp.getYieldOp();
     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
+      Block *block = yieldOp->getBlock();
       if (!it.value().getType().isa<TensorType>())
         continue;
-      if (!state.areEquivalentBufferizedValues(
-              it.value(), yieldOp->getBlock()->getArgument(it.index())))
+      if (it.index() >= block->getNumArguments() ||
+          !state.areEquivalentBufferizedValues(it.value(),
+                                               block->getArgument(it.index())))
         return yieldOp->emitError()
                << "Yield operand #" << it.index()
                << " is not equivalent to the corresponding iter bbArg";
index 140f67b7c30241bca0081d44a2a5f2537ab33c0d..a4d2818e911f10e66e1e0b2d56e5cfcfe0d26269 100644 (file)
@@ -314,3 +314,17 @@ func.func @destination_passing_style_dominance_test_2(%cst : f32, %idx : index,
   %r = tensor.extract %2[%idx2] : tensor<?xf32>
   return %r : f32
 }
+
+// -----
+
+func.func @regression_scf_while() {
+  %false = arith.constant false
+  %8 = bufferization.alloc_tensor() : tensor<10x10xf32>
+  scf.while (%arg0 = %8) : (tensor<10x10xf32>) -> () {
+    scf.condition(%false)
+  } do {
+    // expected-error @+1 {{Yield operand #0 is not equivalent to the corresponding iter bbArg}}
+    scf.yield %8 : tensor<10x10xf32>
+  }
+  return
+}